import os.path
import glob
from .listdataset import ListDataset
from .util import split2list


def make_dataset(dir, split=None):
    """Gather (image pair, flow map) samples from ``dir`` and split them.

    A sample is a triplet of files sharing one prefix:
    ``[name]_img1.ppm``, ``[name]_img2.ppm`` and ``[name]_flow.flo``.
    Flow maps are visited in sorted order; a flow map is ignored when
    either of its two images is missing from ``dir``.

    Each sample is stored as ``[[img1, img2], flow_map]`` using file names
    relative to ``dir`` (not full paths).

    The collected list is handed to ``split2list`` together with ``split``
    and a default split of 0.97; its result is returned unchanged.
    """
    flow_suffix = '_flow.flo'
    samples = []

    flow_paths = glob.glob(os.path.join(dir, '*_flow.flo'))
    flow_paths.sort()

    for flow_path in flow_paths:
        flow_name = os.path.basename(flow_path)
        # Drop the '_flow.flo' suffix to recover the shared sample prefix.
        prefix = flow_name[:-len(flow_suffix)]
        image_pair = [prefix + '_img1.ppm', prefix + '_img2.ppm']

        # Keep the sample only when both frames are present on disk.
        if all(os.path.isfile(os.path.join(dir, name)) for name in image_pair):
            samples.append([image_pair, flow_name])

    return split2list(samples, split, default_split=0.97)


def flying_chairs(root, transform=None, target_transform=None,
                  co_transform=None, split=None):
    """Build the train and test ``ListDataset`` objects for ``root``.

    Args:
        root: directory containing the image and flow files.
        transform: forwarded to both datasets.
        target_transform: forwarded to both datasets.
        co_transform: forwarded to the train dataset only; the test dataset
            is built without it.
        split: selects how samples are divided into train and test lists.
            * A string that is not made only of digits is treated as the
              location of the split files and passed to ``load_split_file``.
            * Anything else (``None``, a digit-only string, or a non-string
              value) is passed straight to ``make_dataset``.
              NOTE(review): numeric values are presumably interpreted as a
              ratio by ``split2list`` -- confirm against that helper.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    # Guard the isdigit() call: the default split=None (and any non-string
    # value) used to raise AttributeError here.
    use_split_file = isinstance(split, str) and not split.isdigit()

    if use_split_file:
        train_list, test_list = load_split_file(root, split)
    else:
        train_list, test_list = make_dataset(root, split)

    train_dataset = ListDataset(root, train_list, transform, target_transform, co_transform)
    test_dataset = ListDataset(root, test_list, transform, target_transform)

    return train_dataset, test_dataset


def load_split_file(root,split):
    """Build train/test sample lists from a predefined split description.

    Two text files are read from beneath ``split``:
    ``FlyingChairs/filelist.txt`` and ``FlyingChairs/train_val_split.txt``.
    They are consumed line by line in lockstep (stopping at the shorter one).

    Each file-list line holds three tab-separated paths relative to
    ``root``: first image, second image, flow map. The matching split line
    is ``"1"`` for a training sample or ``"2"`` for a test sample; any other
    label is ignored. Samples with a missing file on disk are skipped, as
    are blank file-list lines.

    Args:
        root: directory the listed paths are joined onto.
        split: directory containing the ``FlyingChairs`` split files.

    Returns:
        ``(train_list, test_list)``, each a list of
        ``[[img1, img2], flo]`` entries holding the joined paths.
    """
    filelist_path = os.path.join(split, "FlyingChairs/filelist.txt")
    labels_path = os.path.join(split, "FlyingChairs/train_val_split.txt")

    # Context managers make sure both handles are closed after reading.
    with open(filelist_path) as handle:
        file_list = handle.readlines()
    with open(labels_path) as handle:
        split_list = handle.readlines()

    train_list = []
    test_list = []
    for entry, label in zip(file_list, split_list):
        # Strip only line terminators: slicing off the last character would
        # corrupt a final line that has no trailing newline.
        entry = entry.rstrip("\r\n")
        label = label.rstrip("\r\n")
        if not entry:
            continue

        fields = entry.split("\t")
        img1 = os.path.join(root, fields[0])
        img2 = os.path.join(root, fields[1])
        flo = os.path.join(root, fields[2])
        if not (os.path.isfile(img1) and os.path.isfile(img2) and os.path.isfile(flo)):
            continue

        if label == "1":
            train_list.append([[img1, img2], flo])
        elif label == "2":
            test_list.append([[img1, img2], flo])
    return train_list, test_list


